{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdfd2e32",
   "metadata": {},
   "source": [
    "# CUDA-MODE session 4 (ch 4 + 5 of the book)\n",
    "\n",
    "Notebook by Thomas Viehmann, based on Jeremy Howard's notebook from lecture 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fd3f4aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build prerequisites for the inline CUDA extensions compiled below:\n",
    "# ninja runs the build, g++-11 is the host compiler and ccache caches its output.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install ninja\n",
    "!sudo apt update\n",
    "!sudo apt install -y g++-11 ccache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "63b545e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.cpp_extension\n",
    "import os\n",
    "os.environ['CXX'] = '/usr/lib/ccache/g++-11'\n",
    "os.environ['CC'] = '/usr/lib/ccache/gcc-11'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "288de6fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using /home/tv/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n",
      "Detected CUDA files, patching ldflags\n",
      "Emitting ninja build file /home/tv/.cache/torch_extensions/py310_cu121/test_ext/build.ninja...\n",
      "Building extension module test_ext...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1/3] /usr/lib/ccache/g++-11 -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_ext -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext/main.cpp -o main.o \n",
      "[2/3] /usr/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /usr/lib/ccache/gcc-11 -DTORCH_EXTENSION_NAME=test_ext -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' --ptxas-options=-v -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext/cuda.cu -o cuda.cuda.o \n",
      "ptxas info    : 1 bytes gmem\n",
      "ptxas info    : Compiling entry function '_Z23rgb_to_grayscale_kernelPhS_i' for 'sm_86'\n",
      "ptxas info    : Function properties for _Z23rgb_to_grayscale_kernelPhS_i\n",
      "    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads\n",
      "ptxas info    : Used 14 registers, 372 bytes cmem[0]\n",
      "[3/3] /usr/lib/ccache/g++-11 main.o cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/lib64 -lcudart -o test_ext.so\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading extension module test_ext...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26.8034565 µs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-03 18:46:05 288753:288753 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n",
      "STAGE:2024-02-03 18:46:05 288753:288753 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-03 18:46:05 288753:288753 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                       cudaLaunchKernel        12.93%      32.585ms        12.93%      32.585ms       3.259us       0.000us         0.00%       0.000us       0.000us         10000  \n",
      "rgb_to_grayscale_kernel(unsigned char*, unsigned cha...         0.00%       0.000us         0.00%       0.000us       0.000us     269.413ms       100.00%     269.413ms      26.941us         10000  \n",
      "                                  cudaDeviceSynchronize        87.07%     219.449ms        87.07%     219.449ms      21.943us       0.000us         0.00%       0.000us       0.000us         10001  \n",
      "-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 252.034ms\n",
      "Self CUDA time total: 269.413ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# based on Jeremy's Lecture 3 notebook\n",
    "# Shared preamble prepended to every CUDA source string in this notebook:\n",
    "# headers, argument-validation macros and a ceiling-division helper.\n",
    "cuda_begin = r'''\n",
    "#include <torch/extension.h>\n",
    "#include <stdio.h>\n",
    "#include <c10/cuda/CUDAException.h>\n",
    "\n",
    "// Argument checks: fail through TORCH_CHECK with a message naming the offending argument.\n",
    "#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x \" must be a CUDA tensor\")\n",
    "#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x \" must be contiguous\")\n",
    "// do/while(0) turns the two checks into a single statement, so CHECK_INPUT stays\n",
    "// correct when it is used as the body of an unbraced if/else.\n",
    "#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)\n",
    "\n",
    "// Ceiling division: how many chunks of size b are needed to cover a items.\n",
    "inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }\n",
    "'''\n",
    "\n",
    "cuda_src = cuda_begin + r'''\n",
    "// One thread per pixel. in is planar: plane p occupies in[p*n] .. in[p*n + n - 1];\n",
    "// planes 0, 1, 2 are weighted as R, G, B and the result is stored in out[0 .. n-1].\n",
    "__global__ void rgb_to_grayscale_kernel(unsigned char* out, unsigned char* in, int n) {\n",
    "    int i = blockIdx.x*blockDim.x + threadIdx.x;\n",
    "    if (i >= n) return;  // threads of the last block that fall past the image\n",
    "    out[i] = 0.2989f*in[i] + 0.5870f*in[i+n] + 0.1140f*in[i+2*n];  // fix with f found by Andreas...\n",
    "}\n",
    "\n",
    "// Converts input (uint8, at least 3 planes x H x W) to grayscale, writing into the\n",
    "// preallocated output (uint8, H x W). Returns output.\n",
    "torch::Tensor rgb_to_grayscale_out(torch::Tensor output, const torch::Tensor& input) {\n",
    "    CHECK_INPUT(input);\n",
    "    CHECK_INPUT(output);  // the kernel writes through a raw pointer with linear indexing\n",
    "    TORCH_CHECK(input.dim() == 3 && input.size(0) >= 3,\n",
    "                \"input must be (C, H, W) with at least 3 channel planes\");\n",
    "    TORCH_CHECK(input.scalar_type() == torch::kByte, \"input must be a uint8 tensor\");\n",
    "    int h = input.size(1);\n",
    "    int w = input.size(2);\n",
    "    // Every condition has to hold, hence && (with || a single match let anything through).\n",
    "    TORCH_CHECK(output.dim() == 2 && (h == output.size(0)) && (w == output.size(1))\n",
    "                && (output.device() == input.device())\n",
    "                && (output.scalar_type() == input.scalar_type()),\n",
    "                \"output must be an (H, W) tensor with the same device and dtype as input\");\n",
    "    int n = w*h;\n",
    "    if (n == 0) return output;  // nothing to convert; a grid of 0 blocks is not a valid launch\n",
    "    int threads = 256;\n",
    "    rgb_to_grayscale_kernel<<<cdiv(n,threads), threads>>>(\n",
    "        output.data_ptr<unsigned char>(), input.data_ptr<unsigned char>(), n);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "\n",
    "// Allocating variant: creates the (H, W) result with the dtype and device of input.\n",
    "torch::Tensor rgb_to_grayscale(const torch::Tensor& input) {\n",
    "    CHECK_INPUT(input);\n",
    "    TORCH_CHECK(input.dim() == 3, \"input must be (C, H, W)\");\n",
    "    int h = input.size(1);\n",
    "    int w = input.size(2);\n",
    "    auto output = torch::empty({h,w}, input.options());\n",
    "    rgb_to_grayscale_out(output, input);\n",
    "    return output;\n",
    "}\n",
    "'''\n",
    "\n",
    "# C++ declarations of the functions defined in cuda_src; load_inline generates\n",
    "# Python bindings for the names listed in functions=.\n",
    "cpp_src = \"\"\"\n",
    "torch::Tensor rgb_to_grayscale(const torch::Tensor& input);\n",
    "torch::Tensor rgb_to_grayscale_out(torch::Tensor output, const torch::Tensor& input);\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "os.environ['CXX'] = '/usr/lib/ccache/g++-11'\n",
    "os.environ['CC'] = '/usr/lib/ccache/gcc-11'\n",
    "\n",
    "# Compile both sources into the test_ext extension and import it. The ptxas flag\n",
    "# makes the build log report register and memory usage of each kernel.\n",
    "module = torch.utils.cpp_extension.load_inline(\n",
    "    name=\"test_ext\",\n",
    "    cpp_sources=cpp_src,\n",
    "    cuda_sources=cuda_src,\n",
    "    functions=['rgb_to_grayscale', 'rgb_to_grayscale_out'],\n",
    "    extra_cuda_cflags=['--ptxas-options=-v'],\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "import time\n",
    "\n",
    "n = 2048\n",
    "t = torch.randint(0, 256, (3, n, n), dtype=torch.uint8, device=\"cuda\")\n",
    "# First call allocates the output buffer that both loops below write into.\n",
    "out = module.rgb_to_grayscale(t)\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# Wall-clock timing: queue every launch, then wait once for the GPU to finish.\n",
    "n_iters = 10_000\n",
    "start_ns = time.perf_counter_ns()\n",
    "for _ in range(n_iters):\n",
    "    module.rgb_to_grayscale_out(out, t)\n",
    "torch.cuda.synchronize()\n",
    "elapsed_ns = time.perf_counter_ns() - start_ns\n",
    "\n",
    "# nanoseconds -> microseconds per call\n",
    "print(elapsed_ns / n_iters / 1_000, \"µs\")\n",
    "\n",
    "# Profiler run: here each call is followed by its own synchronize.\n",
    "with torch.profiler.profile() as prof:\n",
    "    for _ in range(n_iters):\n",
    "        module.rgb_to_grayscale_out(out, t)\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "print(prof.key_averages().table())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "442400bd",
   "metadata": {},
   "source": [
    "# Approximate gelu as a fusion example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1c8c558",
   "metadata": {},
   "outputs": [],
   "source": [
    "# as per the pytorch doc, implemented manually\n",
    "def gelu(x):\n",
    "    return 0.5 * x * (1+ torch.tanh((2/torch.pi)**0.5 * (x+0.044715 * x**3)))\n",
    "\n",
    "x = torch.randn(1024, 1024, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0733f1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gelu(x) - torch.nn.functional.gelu(x, approximate='tanh')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d782da2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "435 µs ± 1.82 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
      "61.6 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit gelu(x); torch.cuda.synchronize()\n",
    "%timeit torch.nn.functional.gelu(x, approximate='tanh'); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5a9f0de",
   "metadata": {},
   "source": [
    "## Kind of slow. Why?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8a54e12d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using /home/tv/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n",
      "Detected CUDA files, patching ldflags\n",
      "Emitting ninja build file /home/tv/.cache/torch_extensions/py310_cu121/test_ext_gelu/build.ninja...\n",
      "Building extension module test_ext_gelu...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1/3] /usr/lib/ccache/g++-11 -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_ext_gelu -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_gelu/main.cpp -o main.o \n",
      "[2/3] /usr/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /usr/lib/ccache/gcc-11 -DTORCH_EXTENSION_NAME=test_ext_gelu -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' --ptxas-options=-v -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_gelu/cuda.cu -o cuda.cuda.o \n",
      "ptxas info    : 1 bytes gmem\n",
      "ptxas info    : Compiling entry function '_Z14my_gelu_kernelPfS_i' for 'sm_86'\n",
      "ptxas info    : Function properties for _Z14my_gelu_kernelPfS_i\n",
      "    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads\n",
      "ptxas info    : Used 12 registers, 372 bytes cmem[0]\n",
      "[3/3] /usr/lib/ccache/g++-11 main.o cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/lib64 -lcudart -o test_ext_gelu.so\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading extension module test_ext_gelu...\n"
     ]
    }
   ],
   "source": [
    "cuda_src = cuda_begin + r'''\n",
    "// Elementwise tanh-approximated GELU: one thread per element of the flattened input.\n",
    "__global__ void my_gelu_kernel(float* out, float* inp, int n) {\n",
    "    int i = blockIdx.x*blockDim.x + threadIdx.x;\n",
    "    if (i >= n) return;  // threads of the last block that fall past the data\n",
    "    float x = inp[i];\n",
    "    out[i] = 0.5f * x * (1.0f+ tanhf(sqrtf(2.0f/3.141592653589793f) * (x+0.044715f * (x*x*x))));\n",
    "}\n",
    "\n",
    "// Computes GELU of inp into the preallocated output (same shape, float32) and returns output.\n",
    "torch::Tensor my_gelu_out(torch::Tensor output, const torch::Tensor& inp) {\n",
    "    CHECK_INPUT(inp);\n",
    "    CHECK_INPUT(output);  // the kernel writes through a raw pointer with linear indexing\n",
    "    TORCH_CHECK(inp.scalar_type() == torch::kFloat, \"inp must be a float32 tensor\");\n",
    "    // All three properties have to match, hence && (with || a single match let anything through).\n",
    "    TORCH_CHECK((output.sizes() == inp.sizes()) && (output.device() == inp.device())\n",
    "                && (output.scalar_type() == inp.scalar_type()),\n",
    "                \"output must have the same shape, device and dtype as inp\");\n",
    "    int n = inp.numel();\n",
    "    if (n == 0) return output;  // a grid of 0 blocks is not a valid launch\n",
    "    int threads = 256;\n",
    "    my_gelu_kernel<<<cdiv(n, threads), threads>>>(\n",
    "        output.data_ptr<float>(), inp.data_ptr<float>(), n);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "\n",
    "// Allocating variant: the result has the shape, dtype and device of inp.\n",
    "torch::Tensor my_gelu(const torch::Tensor& inp) {\n",
    "    CHECK_INPUT(inp);\n",
    "    auto output = torch::empty_like(inp);\n",
    "    my_gelu_out(output, inp);\n",
    "    return output;\n",
    "}\n",
    "'''\n",
    "\n",
    "# C++ declarations matching the definitions in cuda_src.\n",
    "cpp_src = \"\"\"\n",
    "torch::Tensor my_gelu(const torch::Tensor& inp);\n",
    "torch::Tensor my_gelu_out(torch::Tensor output, const torch::Tensor& inp);\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "os.environ['CXX'] = '/usr/lib/ccache/g++-11'\n",
    "os.environ['CC'] = '/usr/lib/ccache/gcc-11'\n",
    "\n",
    "# Build and import the extension; the ptxas flag prints per-kernel resource usage.\n",
    "gelu_module = torch.utils.cpp_extension.load_inline(\n",
    "    name=\"test_ext_gelu\",\n",
    "    cpp_sources=cpp_src,\n",
    "    cuda_sources=cuda_src,\n",
    "    functions=['my_gelu', 'my_gelu_out'],\n",
    "    extra_cuda_cflags=['--ptxas-options=-v'],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec3f1b5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.3842e-07, device='cuda:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(gelu_module.my_gelu(x) - gelu(x)).abs().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e3427d1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "55.9 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit gelu_module.my_gelu(x); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82fabbbd",
   "metadata": {},
   "source": [
    "# Empty kernel to measure launch latency\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f6752592",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using /home/tv/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n",
      "Detected CUDA files, patching ldflags\n",
      "Emitting ninja build file /home/tv/.cache/torch_extensions/py310_cu121/test_ext_empty/build.ninja...\n",
      "Building extension module test_ext_empty...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1/3] /usr/lib/ccache/g++-11 -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_ext_empty -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_empty/main.cpp -o main.o \n",
      "[2/3] /usr/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /usr/lib/ccache/gcc-11 -DTORCH_EXTENSION_NAME=test_ext_empty -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' --ptxas-options=-v -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_empty/cuda.cu -o cuda.cuda.o \n",
      "ptxas info    : 1 bytes gmem\n",
      "ptxas info    : Compiling entry function '_Z15my_empty_kernelPfS_i' for 'sm_86'\n",
      "ptxas info    : Function properties for _Z15my_empty_kernelPfS_i\n",
      "    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads\n",
      "ptxas info    : Used 4 registers, 372 bytes cmem[0]\n",
      "[3/3] /usr/lib/ccache/g++-11 main.o cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/lib64 -lcudart -o test_ext_empty.so\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading extension module test_ext_empty...\n"
     ]
    }
   ],
   "source": [
    "cuda_src = cuda_begin + r'''\n",
    "// Deliberately empty body: launching this kernel measures launch overhead only.\n",
    "__global__ void my_empty_kernel(float* out, float* inp, int n) {\n",
    "}\n",
    "\n",
    "// Validates its arguments and launches the empty kernel with the same grid shape as\n",
    "// my_gelu_out; output is returned untouched.\n",
    "torch::Tensor my_empty_out(torch::Tensor output, const torch::Tensor& inp) {\n",
    "    CHECK_INPUT(inp);\n",
    "    CHECK_INPUT(output);\n",
    "    TORCH_CHECK(inp.scalar_type() == torch::kFloat, \"inp must be a float32 tensor\");\n",
    "    // All three properties have to match, hence && (with || a single match let anything through).\n",
    "    TORCH_CHECK((output.sizes() == inp.sizes()) && (output.device() == inp.device())\n",
    "                && (output.scalar_type() == inp.scalar_type()),\n",
    "                \"output must have the same shape, device and dtype as inp\");\n",
    "    int n = inp.numel();\n",
    "    if (n == 0) return output;  // a grid of 0 blocks is not a valid launch\n",
    "    int threads = 256;\n",
    "    my_empty_kernel<<<cdiv(n, threads), threads>>>(\n",
    "        output.data_ptr<float>(), inp.data_ptr<float>(), n);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "\n",
    "// Allocating variant: creates an uninitialized tensor like inp and passes it through.\n",
    "torch::Tensor my_empty(const torch::Tensor& inp) {\n",
    "    CHECK_INPUT(inp);\n",
    "    auto output = torch::empty_like(inp);\n",
    "    my_empty_out(output, inp);\n",
    "    return output;\n",
    "}\n",
    "'''\n",
    "\n",
    "# C++ declarations matching the definitions in cuda_src.\n",
    "cpp_src = \"\"\"\n",
    "torch::Tensor my_empty(const torch::Tensor& inp);\n",
    "torch::Tensor my_empty_out(torch::Tensor output, const torch::Tensor& inp);\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "os.environ['CXX'] = '/usr/lib/ccache/g++-11'\n",
    "os.environ['CC'] = '/usr/lib/ccache/gcc-11'\n",
    "\n",
    "# Build and import the extension; the ptxas flag prints per-kernel resource usage.\n",
    "empty_module = torch.utils.cpp_extension.load_inline(\n",
    "    name=\"test_ext_empty\",\n",
    "    cpp_sources=cpp_src,\n",
    "    cuda_sources=cuda_src,\n",
    "    functions=['my_empty', 'my_empty_out'],\n",
    "    extra_cuda_cflags=['--ptxas-options=-v'],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2111adbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15.6 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "STAGE:2024-02-03 18:47:47 288753:288753 ActivityProfilerController.cpp:314] Completed Stage: Warm Up\n",
      "STAGE:2024-02-03 18:47:47 288753:288753 ActivityProfilerController.cpp:320] Completed Stage: Collection\n",
      "STAGE:2024-02-03 18:47:47 288753:288753 ActivityProfilerController.cpp:324] Completed Stage: Post Processing\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  \n",
      "----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "                        cudaLaunchKernel        80.81%      42.183ms        80.81%      42.183ms       4.218us       0.000us         0.00%       0.000us       0.000us         10000  \n",
      "    my_empty_kernel(float*, float*, int)         0.00%       0.000us         0.00%       0.000us       0.000us      30.000ms       100.00%      30.000ms       3.000us         10000  \n",
      "                   cudaDeviceSynchronize        19.19%      10.020ms        19.19%      10.020ms       1.002us       0.000us         0.00%       0.000us       0.000us         10001  \n",
      "----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n",
      "Self CPU time total: 52.203ms\n",
      "Self CUDA time total: 30.000ms\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Cost of one do-nothing launch followed by a synchronize.\n",
    "%timeit empty_module.my_empty_out(x, x); torch.cuda.synchronize()\n",
    "\n",
    "# Profile the same call pattern to see launch, kernel and synchronize separately.\n",
    "n_calls = 10_000\n",
    "with torch.profiler.profile() as prof:\n",
    "    for _ in range(n_calls):\n",
    "        empty_module.my_empty_out(x, x)\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "summary_table = prof.key_averages().table()\n",
    "print(summary_table)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836d9c28",
   "metadata": {},
   "source": [
    "## Matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eb5aa965",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using /home/tv/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n",
      "Detected CUDA files, patching ldflags\n",
      "Emitting ninja build file /home/tv/.cache/torch_extensions/py310_cu121/test_ext_simple_matmul/build.ninja...\n",
      "Building extension module test_ext_simple_matmul...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ninja: no work to do.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading extension module test_ext_simple_matmul...\n"
     ]
    }
   ],
   "source": [
    "cuda_src = cuda_begin + r'''\n",
    "// Naive matmul, one thread per output element:\n",
    "// out (h x w) = m (h x k) @ n (k x w), all stored row-major.\n",
    "__global__ void simple_matmul_k(float* m, float* n, float* out, int h, int w, int k) {\n",
    "    int r = blockIdx.y*blockDim.y + threadIdx.y;\n",
    "    int c = blockIdx.x*blockDim.x + threadIdx.x;\n",
    "\n",
    "    if (r>=h || c>=w) return;  // threads of edge blocks that fall outside the output\n",
    "    float o = 0;\n",
    "    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];\n",
    "    out[r*w+c] = o;\n",
    "}\n",
    "\n",
    "// Returns m @ n for two contiguous 2-D float32 CUDA tensors.\n",
    "torch::Tensor simple_matmul(const torch::Tensor& m, const torch::Tensor& n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    TORCH_CHECK(m.dim() == 2 && n.dim() == 2, \"simple_matmul expects two 2-D tensors\");\n",
    "    TORCH_CHECK(m.scalar_type() == torch::kFloat && n.scalar_type() == torch::kFloat,\n",
    "                \"simple_matmul expects float32 tensors\");\n",
    "    TORCH_CHECK(m.device() == n.device(), \"m and n must be on the same device\");\n",
    "    int h = m.size(0);\n",
    "    int w = n.size(1);\n",
    "    int k = m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch!\");\n",
    "    // empty is sufficient: the kernel stores every element of the output\n",
    "    auto output = torch::empty({h, w}, m.options());\n",
    "    if (output.numel() == 0) return output;  // a grid with a 0 dimension is not a valid launch\n",
    "\n",
    "    dim3 tpb(16,16);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "    simple_matmul_k<<<blocks, tpb>>>(\n",
    "        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "'''\n",
    "\n",
    "# C++ declaration matching the definition in cuda_src.\n",
    "cpp_src = \"\"\"\n",
    "torch::Tensor simple_matmul(const torch::Tensor& m, const torch::Tensor& n);\n",
    "\"\"\"\n",
    "\n",
    "# Build and import the extension; the ptxas flag prints per-kernel resource usage.\n",
    "simple_matmul_module = torch.utils.cpp_extension.load_inline(\n",
    "    name=\"test_ext_simple_matmul\",\n",
    "    cpp_sources=cpp_src,\n",
    "    cuda_sources=cuda_src,\n",
    "    functions=['simple_matmul'],\n",
    "    extra_cuda_cflags=['--ptxas-options=-v'],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ba92c3d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "934 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.0002, device='cuda:0')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.randn(1024, 1024, device=\"cuda\")\n",
    "b = torch.randn(1024, 1024, device=\"cuda\")\n",
    "# Kernel launches are asynchronous: synchronize inside the timed statement so that\n",
    "# %timeit waits for the kernel to finish, as the gelu timing cells do.\n",
    "%timeit simple_matmul_module.simple_matmul(a, b); torch.cuda.synchronize()\n",
    "\n",
    "# Largest absolute deviation from the built-in matmul.\n",
    "(simple_matmul_module.simple_matmul(a, b) - a@b).abs().max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dd38a25",
   "metadata": {},
   "source": [
    "## Tiled matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "47b4e127",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using /home/tv/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n",
      "Detected CUDA files, patching ldflags\n",
      "Emitting ninja build file /home/tv/.cache/torch_extensions/py310_cu121/test_ext_tiled_matmul/build.ninja...\n",
      "Building extension module test_ext_tiled_matmul...\n",
      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1/3] /usr/lib/ccache/g++-11 -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=test_ext_tiled_matmul -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_tiled_matmul/main.cpp -o main.o \n",
      "[2/3] /usr/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /usr/lib/ccache/gcc-11 -DTORCH_EXTENSION_NAME=test_ext_tiled_matmul -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1016\\\" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' --ptxas-options=-v -std=c++17 -c /home/tv/.cache/torch_extensions/py310_cu121/test_ext_tiled_matmul/cuda.cu -o cuda.cuda.o \n",
      "ptxas info    : 1 bytes gmem\n",
      "ptxas info    : Compiling entry function '_Z19tiled_matmul_kernelPfS_S_iii' for 'sm_86'\n",
      "ptxas info    : Function properties for _Z19tiled_matmul_kernelPfS_S_iii\n",
      "    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads\n",
      "ptxas info    : Used 36 registers, 2048 bytes smem, 388 bytes cmem[0]\n",
      "[3/3] /usr/lib/ccache/g++-11 main.o cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/lib64 -lcudart -o test_ext_tiled_matmul.so\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading extension module test_ext_tiled_matmul...\n"
     ]
    }
   ],
   "source": [
    "cuda_src = cuda_begin + r\"\"\"\n",
    "constexpr int TILE_SIZE = 16;\n",
    "\n",
    "// Tiled matmul: out (h x w) = M (h x k) @ N (k x w), all stored row-major.\n",
    "// Each TILE_SIZE x TILE_SIZE thread block produces one output tile and walks along k,\n",
    "// staging one tile of M and one tile of N in shared memory per step.\n",
    "__global__ void tiled_matmul_kernel(float* out, float* M, float* N, int h, int w, int k) {\n",
    "  __shared__ float M_tile[TILE_SIZE][TILE_SIZE];\n",
    "  __shared__ float N_tile[TILE_SIZE][TILE_SIZE];\n",
    "\n",
    "  // idxes into tile\n",
    "  int ir = threadIdx.y;\n",
    "  int ic = threadIdx.x;\n",
    "\n",
    "  int r = blockIdx.y * blockDim.y + threadIdx.y;\n",
    "  int c = blockIdx.x * blockDim.x + threadIdx.x;\n",
    "\n",
    "  // note: cannot just exit if we want to do padding!\n",
    "  // (threads outside the output still fill tile entries and must reach every __syncthreads)\n",
    "\n",
    "  float res = 0.0f;\n",
    "  const int num_k_tiles = (k + TILE_SIZE - 1) / TILE_SIZE;\n",
    "  for (int K_tileidx = 0; K_tileidx < num_k_tiles; K_tileidx++) {\n",
    "    // note how threadIdx.x is the fastest moving bit --> coalesced memory access\n",
    "    // entries outside M / N are loaded as 0 so that they add nothing to the sum\n",
    "    M_tile[ir][ic] = (((r < h) && (K_tileidx * TILE_SIZE + ic < k)) ? M[r * k + K_tileidx * TILE_SIZE + ic] : 0.f);\n",
    "    N_tile[ir][ic] = ((((K_tileidx * TILE_SIZE + ir) < k) && (c < w)) ? N[(K_tileidx * TILE_SIZE + ir) * w + c] : 0.f);\n",
    "    // unpadded loads, only valid when h, w and k are multiples of TILE_SIZE:\n",
    "    //M_tile[ir][ic] = M[r * k + K_tileidx * TILE_SIZE + ic];\n",
    "    //N_tile[ir][ic] = N[(K_tileidx * TILE_SIZE + ir) * w + c];\n",
    "    __syncthreads();  // both tiles must be completely written before any thread reads them\n",
    "    for (int idx = 0; idx < TILE_SIZE; idx++) {\n",
    "       res += M_tile[ir][idx] * N_tile[idx][ic];\n",
    "    }\n",
    "    // important! (why?) the next iteration overwrites both tiles, so every thread\n",
    "    // has to be done reading them first\n",
    "    __syncthreads();\n",
    "  }\n",
    "  if ((r < h) && (c < w)) {\n",
    "    out[r * w + c] = res;\n",
    "  }\n",
    "}\n",
    "\n",
    "// Returns m @ n for two contiguous 2-D float32 CUDA tensors; sizes need not be\n",
    "// multiples of TILE_SIZE because the kernel zero-pads partial tiles.\n",
    "torch::Tensor tiled_matmul(const torch::Tensor& m, const torch::Tensor& n) {\n",
    "    CHECK_INPUT(m); CHECK_INPUT(n);\n",
    "    TORCH_CHECK(m.dim() == 2 && n.dim() == 2, \"tiled_matmul expects two 2-D tensors\");\n",
    "    TORCH_CHECK(m.scalar_type() == torch::kFloat && n.scalar_type() == torch::kFloat,\n",
    "                \"tiled_matmul expects float32 tensors\");\n",
    "    TORCH_CHECK(m.device() == n.device(), \"m and n must be on the same device\");\n",
    "    int h = m.size(0);\n",
    "    int w = n.size(1);\n",
    "    int k = m.size(1);\n",
    "    TORCH_CHECK(k==n.size(0), \"Size mismatch\");\n",
    "    // only needed together with the unpadded loads in the kernel:\n",
    "    //TORCH_CHECK((k % TILE_SIZE == 0) && (h % TILE_SIZE == 0) && (w % TILE_SIZE == 0), \"Padding not done\");\n",
    "    auto output = torch::empty({h, w}, m.options());\n",
    "    if (output.numel() == 0) return output;  // a grid with a 0 dimension is not a valid launch\n",
    "\n",
    "    // the block shape has to equal the tile shape: the kernel indexes tiles by threadIdx\n",
    "    dim3 tpb(TILE_SIZE, TILE_SIZE);\n",
    "    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));\n",
    "    tiled_matmul_kernel<<<blocks, tpb>>>(\n",
    "        output.data_ptr<float>(), m.data_ptr<float>(), n.data_ptr<float>(), h, w, k);\n",
    "    C10_CUDA_KERNEL_LAUNCH_CHECK();\n",
    "    return output;\n",
    "}\n",
    "\n",
    "\"\"\"\n",
    "# C++ declaration matching the definition in cuda_src.\n",
    "cpp_src = \"\"\"\n",
    "torch::Tensor tiled_matmul(const torch::Tensor& m, const torch::Tensor& n);\n",
    "\"\"\"\n",
    "\n",
    "# Build and import the extension; the ptxas flag prints per-kernel resource usage.\n",
    "tiled_matmul_module = torch.utils.cpp_extension.load_inline(\n",
    "    name=\"test_ext_tiled_matmul\",\n",
    "    cpp_sources=cpp_src,\n",
    "    cuda_sources=cuda_src,\n",
    "    functions=['tiled_matmul'],\n",
    "    extra_cuda_cflags=['--ptxas-options=-v'],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2b9ceaa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "707 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Synchronize inside the timed statement: a kernel launch returns before the GPU is done.\n",
    "%timeit tiled_matmul_module.tiled_matmul(a, b); torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8e0f2627",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(4.5776e-05, device='cuda:0')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "aa = torch.randn(500, 200, device=\"cuda\")\n",
    "bb = torch.randn(200, 1000, device=\"cuda\")\n",
    "\n",
    "\n",
    "(tiled_matmul_module.tiled_matmul(aa, bb) - aa@bb).abs().max()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68bc3b49",
   "metadata": {},
   "source": [
    "# Occupancy?\n",
    "\n",
    "- shared memory: 64 KB limit / 2 KB used per block -> 32 blocks\n",
    "- threads: 1536 thread limit / 256 threads per block -> 6 blocks\n",
    "\n",
    "$\\rightarrow$ we could afford larger tiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50a78320",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659c02af",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
